from collections import Counter

import numpy as np

from src.clustering.base_distances_bis import (
    norm_hamming,
    jaccard,
    rms_jaccard,
    geom_jaccard,
    rms_hamming,
    geom_hamming,
)
from src.converters.convert_votes_to_vecs import (
    convert_profiles_to_vec_profiles
)
from src.imports.import_pabulib import import_pabulib_files_from_folder

# Registry of the base distance functions imported above, keyed by name.
base_distances = dict(
    jaccard=jaccard,
    geom_jaccard=geom_jaccard,
    norm_hamming=norm_hamming,
    rms_jaccard=rms_jaccard,
    geom_hamming=geom_hamming,
    rms_hamming=rms_hamming,
)

if __name__ == "__main__":

    targets = ['warszawa_2023_districts','warszawa_2024_districts',
               'krakow_2023_districts', 'krakow_2024_districts',
                'lodz_2023_districts', 'lodz_2024_districts']

    for target in targets:

        instances, profiles = import_pabulib_files_from_folder(f'data/pabulib/{target}')

        vec_profiles, names = convert_profiles_to_vec_profiles(profiles)


        stat_num_cat = []
        stat_num_proj = []
        stat_num_pc = []
        stat_num_votes = []

        for instance_id in instances:

            num_projects = len(vec_profiles[instance_id])
            stat_num_proj.append(num_projects)

            P = np.array(vec_profiles[instance_id])
            P = np.transpose(P)
            value = len(P)
            stat_num_votes.append(value)

            ALL_INTRA = {base_distance: {} for base_distance in base_distances}
            ALL_INTER = {base_distance: {} for base_distance in base_distances}
            ALL_CLOSEST = {base_distance: {} for base_distance in base_distances}


            intra_party_distances = []
            inter_party_distances = []
            closest_friend_ratio = [0 for _ in range(num_projects)]

            id_0 = names[instance_id][0]
            if 'categories' not in instances[instance_id].project_meta[id_0]:
                break

            # get list of all categories
            categories = dict()
            for i in range(num_projects):
                i_id = names[instance_id][i]
                i_cat = instances[instance_id].project_meta[i_id]['categories']
                for c in i_cat:
                    if c not in set(categories.keys()):
                        categories[c] = 1
                    else:
                        categories[c] += 1


            # print number of categories with at least 2 projects
            value = len([c for c in categories if categories[c] > 1])
            stat_num_cat.append(value)

            values = [c for c in categories.values()]
            values = np.array(values)
            value = np.mean(values)
            stat_num_pc.append(value)

        print(target)
        stat_num_cat = np.array(stat_num_cat)
        stat_num_proj = np.array(stat_num_proj)
        stat_num_pc = np.array(stat_num_pc)
        stat_num_votes = np.array(stat_num_votes)

        print(f"Number of instances: {len(instances)}")
        print(f"Number of voters: {round(stat_num_votes.mean(),2)} +/- {round(stat_num_votes.std(),2)}")
        print(f"Number of projects: {round(stat_num_proj.mean(),2)} +/- {round(stat_num_proj.std(),2)}")
        print(f"Number of categories: {round(stat_num_cat.mean(),2)} +/- {round(stat_num_cat.std(),2)}")
        print(f"Number of proj. per cat: {round(stat_num_pc.mean(),2)} +/- {round(stat_num_pc.std(),2)}")

